package com.xx.websocket.server;

import com.alibaba.fastjson2.JSON;
import com.xx.web.domain.entity.EarlyWarning;
import com.xx.web.service.IEarlyWarningService;
import com.xx.websocket.config.JsonEncoder;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

import javax.websocket.*;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * @author fmy
 * @version 1.0
 * @ClassName: WebSocketServer
 * @Description:
 * @Date: 2023/4/11 14:14
 * @since JDK 1.8
 */
@Slf4j
@Component
@ServerEndpoint(value = "/wb/early/{place}/{key}", encoders = JsonEncoder.class)
public class EarlyPushServer {

    private static AtomicInteger onlineCount = new AtomicInteger(0);

    //private static CopyOnWriteArraySet<EarlyPushServer> webSocketSet = new CopyOnWriteArraySet<>();
    //public static Map<String, Session> sessionPool = new HashMap<String, Session>();
    public static Map<String, CopyOnWriteArraySet<EarlyPushServer>> webSocketMap = new HashMap<>();
    private Session session;
    private static IEarlyWarningService earlyWarningService;

    @Autowired
    public void setEarlyWarningService(IEarlyWarningService earlyWarningService) {
        EarlyPushServer.earlyWarningService = earlyWarningService;
    }

    @OnOpen
    public void onOpen(Session session, @PathParam("place") String place, @PathParam("key") String key) {
        this.session = session;
        //sessionPool.put(key, session);
        //加入set
        System.out.println("连接place: " + place);
        CopyOnWriteArraySet<EarlyPushServer> earlyPushServers = webSocketMap.get(place);
        if (earlyPushServers == null) {
            CopyOnWriteArraySet<EarlyPushServer> webSocketSet = new CopyOnWriteArraySet<>();
            webSocketSet.add(this);
            webSocketMap.put(place, webSocketSet);
        } else {

            earlyPushServers.add(this);
        }
        addOnlineCount();
        log.info("有新连接加入！当前在线人数为" + getOnlineCount() + " session: " + session.getId());
        List<EarlyWarning> earlyWarnings = earlyWarningService.pushList();
        if (earlyWarnings.size() > 0) {
            for (EarlyWarning earlyWarning : earlyWarnings) {
                sendSyncMessage(JSON.toJSONString(earlyWarning));
            }
        }
    }

    @OnClose
    public void onClose(Session session, @PathParam("place") String place, @PathParam("key") String key) {
        webSocketMap.get(place).remove(this);
        subOnlineCount();
        log.info("有一连接关闭！当前在线人数为" + getOnlineCount());
    }

    public void optClose(Session session) {
        if (session.isOpen()) {
            try {
                CloseReason closeReason = new CloseReason(CloseReason.CloseCodes.NORMAL_CLOSURE, "鉴权失败!");
                session.close(closeReason);
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }

    /**
     * 收到客户端消息后调用的方法
     * {"sendType":"message","data":""} message 表示发送的是消息 heart表示心跳 则data必须是ping
     *
     * @param message 客户端发送过来的消息
     */
    @OnMessage
    public void onMessage(String message, Session session, @PathParam("key") String key) {
        if (!StringUtils.isEmpty(message)) {
            if ("ping".equals(message)) {
                sendSyncMessage("pong");
            }
        }
    }

    /**
     * @param session
     * @param error
     */
    @OnError
    public void onError(Session session, Throwable error) {
        log.error("发生错误");
        error.printStackTrace();
    }

    /**
     * 指定发送自定义消息
     */
    public static void sendInfo(Session session, String message) {
        if (session == null) {
            return;
        }
        try {
            System.out.println("发送消息:" + message);
            session.getBasicRemote().sendText(message);
        } catch (Exception e) {
            log.error("服务端给客户端群发消息失败: ", e);
        }

    }

    public static void sendInfo2All(String place, String message) {

        //for (EarlyPushServer earlyPushServer : webSocketSet) {
        CopyOnWriteArraySet<EarlyPushServer> pushServers = webSocketMap.get(place);
        if (!CollectionUtils.isEmpty(pushServers)) {
            for (EarlyPushServer earlyPushServer : pushServers) {
                try {
                    earlyPushServer.session.getBasicRemote().sendText(message);
                } catch (Exception e) {
                    log.error("服务端给客户端群发消息失败: ", e);
                }
            }
        }
    }


    /**
     * 发送信息
     *
     * @param message
     */
    private void sendSyncMessage(String message) {
        try {
            this.session.getBasicRemote().sendText(message);
        } catch (Exception e) {
            log.error("服务端给客户端发送消息失败: ", e);
        }
    }

    public void sendASyncMessage(byte[] message) {
        try {
            this.session.getAsyncRemote().sendBinary(ByteBuffer.wrap(message));
        } catch (Exception e) {
            log.error("服务端给客户端发送消息失败: ", e);
        }
    }

    /**
     * 发送信息
     *
     * @param message
     */
    private void sendASyncMessage(String message) {
        try {
            this.session.getAsyncRemote().sendText(message);
        } catch (Exception e) {
            log.error("服务端给客户端发送消息失败: ", e);
        }
    }
    // 此为广播消息
/*
    public static void sendAllMessage(String message) {
        for (EarlyPushServer webSocket : webSocketSet) {
//            System.out.println("【websocket消息】广播消息:" + message);
            try {
                webSocket.session.getAsyncRemote().sendText(message);
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }
    public static void sendAllMessage(ByteBuffer byteBuffer) {
//        System.out.println(new Date());
        for (EarlyPushServer webSocket : webSocketSet) {
            try {
//                synchronized (webSocket.session){
//                 webSocket.session.getAsyncRemote().sendBinary(byteBuffer);
                    webSocket.session.getBasicRemote().sendBinary(byteBuffer);
//                }

            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }
    // 此为广播消息
    public static void sendAllObjMessage(Object message) {
        System.out.println(webSocketSet.size());
        for (EarlyPushServer webSocket : webSocketSet) {
            try {
                webSocket.session.getAsyncRemote().sendObject(message);
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }
*/

    // 此为单点消息 (发送对象)
    public void sendObjMessage(Object message) {
        if (this.session != null) {
            try {
                this.session.getAsyncRemote().sendObject(message);
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }

    public AtomicInteger getOnlineCount() {
        return onlineCount;
    }

    private void addOnlineCount() {
        onlineCount.incrementAndGet();
    }

    public void subOnlineCount() {
        onlineCount.decrementAndGet();
    }
}
